import asyncio
import time

import tiktoken
from loguru import logger
from openai import AsyncClient

from llm_generation.config import DEEPSEEK_MAX_TOKEN_LENGTH
from llm_generation.models.base import BaseModel


class DeepSeek(BaseModel):
    """Chat model wrapper that queries a DeepSeek model through an OpenAI-compatible API.

    A fresh ``AsyncClient`` is opened for every request and closed when the
    request (including any streamed body) has been fully consumed.
    """

    def __init__(
        self, model_name: str = "deepseek-ai/DeepSeek-V3", base_url=None, api_key=None
    ):
        """Store the connection settings for later requests.

        Args:
            model_name: Model identifier forwarded to the base class and sent
                as ``model`` on every completion request.
            base_url: Base URL of the OpenAI-compatible endpoint. Required.
            api_key: API key for the endpoint. Required.

        Raises:
            ValueError: If ``base_url`` or ``api_key`` is ``None``.
        """
        super().__init__(model_name)
        if base_url is None:
            raise ValueError("base_url is required")
        if api_key is None:
            raise ValueError("api_key is required")
        self.base_url = base_url
        self.api_key = api_key

    async def generate_response(
        self, user_prompt: str, conversation: list = None, **kwargs
    ) -> str:
        """Send ``user_prompt`` to the model and return the generated text.

        Args:
            user_prompt: Text of the new user message. It is truncated to at
                most ``DEEPSEEK_MAX_TOKEN_LENGTH`` tokens before sending.
            conversation: Optional earlier messages placed before the new
                user message. Not modified by this call.
            **kwargs: Forwarded unchanged to ``chat.completions.create``. A
                truthy ``stream`` value switches to streaming mode.

        Returns:
            In streaming mode, the accumulated reasoning text (wrapped in
            ``<think>`` tags) followed by the answer text. Otherwise the
            ``content`` of the first choice's message.
        """
        # Truncate the user prompt to MAX_TOKEN_LENGTH tokens. The gpt-4o
        # tiktoken encoding is used as a temporary stand-in for DeepSeek's
        # own tokenizer, so the count is only approximate.
        tokenizer = tiktoken.encoding_for_model("gpt-4o")
        user_prompt_tokens = tokenizer.encode(user_prompt)
        if len(user_prompt_tokens) > DEEPSEEK_MAX_TOKEN_LENGTH:
            user_prompt = tokenizer.decode(
                user_prompt_tokens[:DEEPSEEK_MAX_TOKEN_LENGTH]
            )

        messages = list(conversation or [])
        messages.append({"role": "user", "content": user_prompt})

        # The context manager closes the client's HTTP resources on exit, so
        # the stream must be fully consumed before leaving this block.
        async with AsyncClient(api_key=self.api_key, base_url=self.base_url) as client:
            response = await client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                **kwargs,
            )
            # Test the value, not mere presence: ``stream=False`` yields a
            # regular completion object that cannot be iterated as a stream.
            if kwargs.get("stream"):
                return await self._collect_stream_content(response)
            return response.choices[0].message.content

    async def _collect_stream_content(self, response) -> str:
        """Consume a streamed completion, feeding the callback and joining the text.

        For each chunk that has choices, the first choice's delta is handed
        to ``self.streaming_callback`` (if set) together with the content
        accumulated *before* that delta. Reasoning deltas are copied into
        ``delta.content`` so the callback sees a single text field; the
        first one is prefixed with ``"<think>\\n"``.

        Args:
            response: Async iterable of completion chunks.

        Returns:
            Reasoning text (if any) followed by the answer text.
        """
        reasoning_content = ""
        response_content = ""
        content = ""
        # Measured from the moment the stream object is available, not from
        # when the request was sent.
        start_generation_time = time.monotonic()
        if self.streaming_callback is None:
            logger.warning(
                "No streaming callback is set, skipping callback function"
            )

        async for chunk in response:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            # Make the delta content compatible with the vleco streaming
            # format: reasoning text travels in ``content``, opened by <think>.
            reasoning_delta = getattr(delta, "reasoning_content", None)
            is_reasoning_content = bool(reasoning_delta)
            if is_reasoning_content:
                if not reasoning_content:
                    delta.content = "<think>\n" + reasoning_delta
                else:
                    delta.content = reasoning_delta

            # Call the streaming callback; await it when it hands back a
            # coroutine so async callbacks (including partial-wrapped ones
            # and callable objects) are actually executed.
            if self.streaming_callback:
                callback_result = self.streaming_callback(content, delta)
                if asyncio.iscoroutine(callback_result):
                    await callback_result

            if not delta.content:
                continue

            # Log first token generation time once.
            if start_generation_time is not None:
                logger.debug(
                    f"First token generation time: {time.monotonic()-start_generation_time}s"
                )
                start_generation_time = None

            if is_reasoning_content:
                reasoning_content += delta.content
            else:
                # Close the reasoning section once, when the first answer
                # token arrives after some reasoning text.
                # NOTE(review): if the stream ends without any answer text,
                # the <think> section is returned unclosed - confirm that
                # consumers expect this.
                if reasoning_content and not reasoning_content.endswith(
                    "\n</think>\n"
                ):
                    reasoning_content += "\n</think>\n"
                response_content += delta.content
            content = f"{reasoning_content}{response_content}"

        return content


async def main():
    """Demo: stream one prompt through DeepSeek and print the final response.

    The endpoint and credentials are read from the ``DEEPSEEK_BASE_URL`` and
    ``DEEPSEEK_API_KEY`` environment variables, because ``DeepSeek`` refuses
    to be constructed without them.

    Raises:
        ValueError: Propagated from ``DeepSeek`` when either variable is unset.
    """
    import os

    deepseek = DeepSeek(
        base_url=os.environ.get("DEEPSEEK_BASE_URL"),
        api_key=os.environ.get("DEEPSEEK_API_KEY"),
    )
    # ``print`` receives (accumulated_content, delta) for every streamed chunk.
    deepseek.set_streaming_callback(print)
    response = await deepseek.generate_response("Hello, how are you?", stream=True)
    print(response)


# Run the streaming demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    asyncio.run(main())
